from typing import Sequence

from functools import partial
import numpy as np
from einops import rearrange
import torch
from torch import nn

from models.utils import MLP


class APE(nn.Module):
    """Absolute positional embedding module.

    Adds a fixed (non-learned) sin/cos embedding to a channel-last input.
    One embedding is computed per grid axis and the per-axis embeddings are
    averaged before being added to the input.

    Args:
        space: Number of grid axes of the input.
        dim: Size of the embedding (last) dimension.
        fuzzy: If True, jitter the integer grid positions with Gaussian noise
            (std 0.1) on every forward pass.
    """

    def __init__(self, space: int, dim: int, fuzzy: bool = False):
        # assert dim % space == 0

        super().__init__()

        self.space = space
        self.dim = dim
        self.fuzzy = fuzzy

    def reset_parameters(self, init_weights: str):
        """No-op: this module has no learnable parameters."""
        _ = init_weights

    def _sincos_1d(self, grid, dim: int, dtype, device, max_wavelength: int = 10000):
        """Return the sin/cos embedding of `grid` with trailing size `dim`.

        The first half of the channels holds the sines, the second half the
        cosines. If `dim` is odd, one zero channel is appended so the output
        still has `dim` channels.
        """
        if dim % 2 == 0:
            padding = None
        else:
            padding = torch.zeros(*grid.shape, 1, dtype=dtype, device=device)
            dim -= 1
        omega = 1.0 / max_wavelength ** (
            torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
        )
        # outer product of positions and frequencies -> (*grid.shape, dim // 2)
        out = grid.unsqueeze(-1) @ omega.unsqueeze(0)
        emb = torch.cat([torch.sin(out), torch.cos(out)], dim=-1).float()
        if padding is None:
            return emb
        return torch.cat([emb, padding], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional embedding to `x`.

        Args:
            x: Input whose dimensions between the first and the last are the
                grid axes; the first `space` of them are used. The last
                dimension is expected to be `dim` so the embedding broadcasts.

        Returns:
            `x` plus the embedding, broadcast over the leading dimension.
        """
        grid_size = x.shape[1:-1]
        axes = [
            torch.arange(grid_size[i], dtype=x.dtype, device=x.device)
            for i in range(self.space)
        ]
        # meshgrid returns a tuple with one coordinate tensor per axis
        grid = torch.meshgrid(*axes, indexing="ij")

        if self.fuzzy:
            # fuzzy grid: the tuple cannot be added to a tensor directly, so
            # jitter every axis' coordinate tensor on its own
            grid = tuple(g + torch.randn_like(g) * 0.1 for g in grid)

        emb = [
            self._sincos_1d(
                grid=grid[i],
                dtype=x.dtype,
                device=x.device,
                dim=self.dim,
                max_wavelength=10000,
            )
            for i in range(self.space)
        ]
        # average the per-axis embeddings and add a broadcast batch axis
        return x + torch.stack(emb, dim=-1).mean(-1)[None]


class RPB(nn.Module):
    """Swinv2 relative position bias (RPB) to learn token distances.

    A small MLP maps log-spaced relative coordinates to one bias value per
    attention head; the biases are then gathered for every pair of tokens.

    Args:
        space: Number of grid axes.
        num_heads: Number of attention heads (bias channels).
    """

    def __init__(self, space: int, num_heads: int):
        super().__init__()

        self.space = space
        self.num_heads = num_heads

        self.cpb_mlp = MLP(
            [space, 512, num_heads],
            act_fn=partial(nn.ReLU, inplace=True),
            bias=[True, False],
        )

    def reset_parameters(self, init_weights: str):
        _ = init_weights
        pass  # TODO

    def _get_relative_coords_table(
        self, grid_size: Sequence[int], device: torch.device
    ):
        """Build the relative-coordinate table and the pairwise lookup index.

        Args:
            grid_size: Number of tokens along each grid axis.
            device: Device on which both returned tensors are created.

        Returns:
            A tuple `(table, index)`. `table` has shape
            `(1, 2*w_0-1, ..., 2*w_k-1, len(grid_size))` and holds the
            log-spaced relative offsets. `index` has shape
            `(num_tokens, num_tokens)` and holds, for every token pair, the
            row of the flattened table that belongs to their offset.
        """
        coords_nd = [
            torch.arange(-(w - 1), w, dtype=torch.float32, device=device)
            for w in grid_size
        ]

        rpb = torch.stack(torch.meshgrid(*coords_nd, indexing="ij"))
        for i in range(self.space):
            # an axis of size 1 only has offset 0; clamp the divisor so it
            # stays 0 instead of becoming 0 / 0 = NaN
            rpb[i] = rpb[i] / max(grid_size[i] - 1, 1)
        # move the coordinate axis last and add a leading batch axis
        rpb = rpb.movedim(0, -1).unsqueeze(0)
        # normalize to -8, 8
        rpb = 8 * rpb
        rpb = torch.sign(rpb) * torch.log2(torch.abs(rpb) + 1.0) / np.log2(8)

        # index with distances; built on `device` so it lives next to the table
        grid = torch.stack(
            torch.meshgrid(
                *[torch.arange(w, device=device) for w in grid_size], indexing="ij"
            )
        )  # (space, wD, wH, wW, wU, wV)
        flat = grid.flatten(1)
        dists = flat.unsqueeze(-1) - flat.unsqueeze(1)

        for i in range(self.space):
            # row-major stride of axis i inside the flattened table
            stride = int(np.prod([2 * w - 1 for w in grid_size[(i + 1) :]]))
            # shift offsets to start at 0, then scale by the stride
            dists[i] = (dists[i] + grid_size[i] - 1) * stride

        return rpb, dists.sum(0)

    def forward(self, x: torch.Tensor, grid_size: Sequence[int]) -> torch.Tensor:
        """Compute the relative position bias for all token pairs.

        Args:
            x: Tensor used only for its device and for `x.shape[2]`, which is
                taken as the number of tokens.
            grid_size: Number of tokens along each grid axis.

        Returns:
            Bias of shape `(1, num_heads, tokens, tokens)` with values in
            the open interval (0, 16).
        """
        # rpb from swinv2
        sl = x.shape[2]
        rpb, rpb_idx = self._get_relative_coords_table(grid_size, x.device)
        rpb = self.cpb_mlp(rpb).view(-1, self.num_heads)
        rpb = rpb[rpb_idx.flatten()].view(sl, sl, self.num_heads)
        rpb = 16 * torch.sigmoid(rpb)
        # (tokens, tokens, heads) -> (1, heads, tokens, tokens)
        return rpb.permute(2, 0, 1).unsqueeze(0).contiguous()


class RotaryPE(nn.Module):
    """N-dimensional rotary position embedding.

    https://github.com/limefax/rope-nd

    The channel dimension is read as `dim // 2` (real, imaginary) pairs.
    They are split evenly over the `space` grid axes, and every pair is
    rotated by `position * theta` of its axis.

    Args:
        space: Number of grid axes.
        dim: Channel size of the tensors to rotate; must be a multiple of
            `2 * space` with at least two frequencies per axis.
        base: Base of the geometric frequency progression.
        use_complex: Rotate via complex multiplication if True, otherwise
            with explicit real cos/sin arithmetic.
    """

    def __init__(
        self,
        space: int,
        dim: int,
        base: float = 10_000,
        use_complex: bool = True,
    ):
        super().__init__()

        k_max = dim // (2 * space)
        self.use_complex = use_complex
        # check divisibility by 2 * space itself: `dim % k_max` accepted dims
        # such as dim=10, space=2 that only failed later in forward, and
        # divided by zero when dim < 2 * space
        assert dim % (2 * space) == 0 and k_max > 1, (
            f"dim ({dim}) must be a multiple of 2 * len(grid_size) "
            f"(={2 * space}) with at least 2 frequencies per axis"
        )
        # tensor of angles to use
        self.theta_ks = 1 / (base ** (torch.arange(k_max) / k_max))

    def _rotations(self, grid_size: Sequence[int], device=None) -> torch.Tensor:
        """Return the rotation for every grid position and channel pair.

        Args:
            grid_size: Number of positions along each grid axis.
            device: Device to build the rotations on; defaults to the device
                of `self.theta_ks`.

        Returns:
            A complex tensor of shape `(*grid_size, pairs)` if
            `use_complex` is set, otherwise a real tensor of shape
            `(*grid_size, pairs, 2)` holding (cos, sin).
        """
        # theta_ks is a plain attribute (not a buffer), so it does not follow
        # the module across devices and has to be moved explicitly
        theta = self.theta_ks if device is None else self.theta_ks.to(device)
        axes = [torch.arange(d, device=theta.device) for d in grid_size]
        # create a stack of angles multiplied by position
        angles = torch.cat(
            [
                pos.unsqueeze(-1) * theta
                for pos in torch.meshgrid(*axes, indexing="ij")
            ],
            dim=-1,
        )
        if self.use_complex:
            # convert to complex number to allow easy rotation
            return torch.polar(torch.ones_like(angles), angles)
        # use real rotation matrix no complex numbers (for bfloat16)
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def forward(
        self, x: torch.Tensor, grid_size: Sequence[int], flatten: bool = False
    ) -> torch.Tensor:
        """Rotate `x` according to its position on the grid.

        Args:
            x: Either `(batch, heads, *grid_size, c)` or, with the grid
                axes already flattened, `(batch, heads, tokens, c)`.
            grid_size: Number of positions along each grid axis.
            flatten: Flatten the grid axes of the result into one token axis.
                Forced to True when `x` arrives flattened.

        Returns:
            The rotated tensor, with the grid axes flattened if requested.
        """
        if x.ndim < len(grid_size) + 3:
            flatten = True
            b, heads, _, c = x.shape
            # reshape to grid for angle multiplication
            x = x.view(b, heads, *grid_size, c)

        # leading singleton axes broadcast over batch and heads
        rotations = self._rotations(grid_size, device=x.device)[None, None]
        # split channels into (re, im) pairs: [..., c // 2, 2]
        pairs = x.view(*x.shape[:-1], -1, 2)
        if self.use_complex:
            # convert input into complex numbers to perform rotation
            rotated = rotations * torch.view_as_complex(pairs)
            pe_x = torch.view_as_real(rotated).flatten(-2)
        else:
            x_re, x_im = pairs[..., 0], pairs[..., 1]
            rot_cos = rotations[..., 0]
            rot_sin = rotations[..., 1]
            # apply rotation with real arithmetic
            pe_x_re = x_re * rot_cos - x_im * rot_sin
            pe_x_im = x_re * rot_sin + x_im * rot_cos
            pe_x = torch.stack([pe_x_re, pe_x_im], dim=-1).flatten(-2)

        if flatten:
            # merge all grid axes into a single token axis
            pe_x = pe_x.flatten(2, -2)
        return pe_x
